import random
import process_data_set
import importlib
import os
import cv2
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import cv2
import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import time
from skimage.measure import shannon_entropy
from ml_tools_utils import utils
utils.pandas_config(pd)
utils.plt_config(plt)
sns.set_theme(style="darkgrid", palette="pastel")
plt.style.use("fivethirtyeight")
# Reload so the notebook picks up edits to process_data_set without a kernel restart.
importlib.reload(process_data_set)
# Download the raw dataset, then report duplicate images per class.
process_data_set.download_ds(process_data_set.TEMP_DATASET_NAME)
dups = process_data_set.find_duplicates(process_data_set.UNPROC_DATASET_LOC)
dups  # notebook cell output: per-class duplicate summary table shown below
| Class | Duplicate Count | Total Images | Proportion | |
|---|---|---|---|---|
| 0 | Agaricus | 2 | 353 | 0.005666 |
| 1 | Amanita | 2 | 750 | 0.002667 |
| 2 | Boletus | 2 | 1073 | 0.001864 |
| 3 | Cortinarius | 2 | 836 | 0.002392 |
| 4 | Entoloma | 0 | 364 | 0.000000 |
| 5 | Hygrocybe | 1 | 316 | 0.003165 |
| 6 | Lactarius | 63 | 1563 | 0.040307 |
| 7 | Russula | 4 | 1147 | 0.003487 |
| 8 | Suillus | 0 | 311 | 0.000000 |
| 9 | Total | 76 | 6713 | 0.011321 |
There are some corrupt and unreadable images in the dataset that also need to be removed:
# Remove unreadable/corrupt image files in place (removed paths shown in the output below).
process_data_set.verify_and_clean_images(process_data_set.UNPROC_DATASET_LOC)
Removing corrupt image: dataset_temp\Mushrooms\Russula\092_43B354vYxm8.jpg due to image file is truncated (92 bytes not processed)
[WindowsPath('dataset_temp/Mushrooms/Russula/092_43B354vYxm8.jpg')]
Image Analysis¶
def get_image_paths(data_dir):
    """Recursively collect image file paths under *data_dir*.

    Args:
        data_dir: Root directory to walk.

    Returns:
        list[str]: Paths of all files with a .png/.jpg/.jpeg extension.
        The extension check is case-insensitive, so .JPG/.PNG files are
        included too (the previous check silently skipped them).
    """
    image_paths = []
    for subdir, _, files in os.walk(data_dir):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                image_paths.append(os.path.join(subdir, file))
    return image_paths
def process_image(image_path, bins=32):
    """Load one image and compute its per-image color statistics.

    Args:
        image_path: Path to the image file on disk.
        bins: Number of histogram bins per color channel.

    Returns:
        None if the file cannot be decoded, otherwise the tuple
        (color_distributions, color_type, original_shape, aspect_ratio,
        variance, unique_colors, entropy, image_path), where
        color_distributions maps the class name (the parent directory
        name) to normalized per-channel histograms.
    """
    image = cv2.imread(image_path)
    if image is None:
        # Unreadable/corrupt file — callers are expected to skip None results.
        return None
    color_type = 'Unknown'
    if len(image.shape) == 2:
        # Single-channel image: expand to 3 channels so the code below can
        # treat every image uniformly.
        # NOTE(review): cv2.imread's default flag decodes to 3-channel BGR,
        # so this branch may never trigger — confirm.
        color_type = 'Grayscale'
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif len(image.shape) == 3:
        if image.shape[2] == 3:
            color_type = 'Color'
        else:
            color_type = f'Other ({image.shape[2]} channels)'
        # OpenCV loads BGR; convert so channel indices 0/1/2 mean R/G/B below.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_shape = image.shape
    aspect_ratio = original_shape[1] / original_shape[0]  # width / height
    image = cv2.resize(image, (100, 100))  # downsample image to reduce size
    # Class label comes from the immediate parent directory of the file.
    class_name = os.path.basename(os.path.dirname(image_path))
    color_distributions = {class_name: {'R': [], 'G': [], 'B': []}}
    for channel, color in enumerate(['R', 'G', 'B']):
        hist = cv2.calcHist([image], [channel], None, [bins], [0, 256])
        # Normalize to a density so differently sized images are comparable.
        hist = hist.flatten() / hist.sum()
        color_distributions[class_name][color] = hist
    # Mean of the per-channel pixel variances of the 100x100 thumbnail.
    variance = np.var(image, axis=(0, 1)).mean()
    # Count of distinct RGB triplets in the thumbnail (at most 10,000).
    unique_colors = len(np.unique(image.reshape(-1, image.shape[2]), axis=0))
    entropy = shannon_entropy(image)
    return color_distributions, color_type, original_shape, aspect_ratio, variance, unique_colors, entropy, image_path
def merge_color_distributions(distributions_list):
    """Average per-image color histograms into one histogram per class.

    Args:
        distributions_list: Iterable of process_image() result tuples (or
            None entries, which are skipped). Only the first element of
            each tuple — {class_name: {'R'|'G'|'B': histogram}} — is used.

    Returns:
        dict: {class_name: {'R'|'G'|'B': averaged histogram}} where each
        channel histogram is the element-wise mean over the class's
        images, so normalized input histograms stay normalized (sum to 1).
    """
    class_counts = {}
    merged_distributions = {}
    for result in distributions_list:
        if result is None:
            continue
        distributions = result[0]
        for class_name, color_dist in distributions.items():
            if class_name not in merged_distributions:
                merged_distributions[class_name] = {
                    'R': np.zeros_like(color_dist['R']),
                    'G': np.zeros_like(color_dist['G']),
                    'B': np.zeros_like(color_dist['B'])
                }
                class_counts[class_name] = 0
            for color in ['R', 'G', 'B']:
                merged_distributions[class_name][color] += color_dist[color]
            # BUG FIX: count each image once per class. The previous code
            # incremented inside the channel loop (3x per image), so every
            # averaged histogram came out at a third of its true density.
            class_counts[class_name] += 1
    for class_name in merged_distributions:
        for color in ['R', 'G', 'B']:
            merged_distributions[class_name][color] /= class_counts[class_name]
    return merged_distributions
def get_color_distributions(image_paths, max_workers=None):
    """Run process_image over all paths in a thread pool and aggregate.

    Returns the merged per-class color distributions plus parallel lists
    of per-image metadata (color type, shape, aspect ratio, variance,
    unique color count, entropy, path) in completion order, with
    unreadable images dropped.
    """
    started = time.time()
    workers = os.cpu_count() if max_workers is None else max_workers
    print(f"Running on {workers} workers")
    # Keep the full result tuples; the per-field lists are unpacked once
    # at the end instead of being appended to one by one.
    results = []
    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = {pool.submit(process_image, path): path for path in image_paths}
        for done in tqdm(as_completed(pending), total=len(pending), desc="Processing images"):
            outcome = done.result()
            if outcome is not None:
                results.append(outcome)
    color_distributions = merge_color_distributions(results)
    color_types = [r[1] for r in results]
    shapes = [r[2] for r in results]
    aspect_ratios = [r[3] for r in results]
    variances = [r[4] for r in results]
    unique_colors = [r[5] for r in results]
    entropies = [r[6] for r in results]
    image_paths_list = [r[7] for r in results]
    print(f"Total processing time: {time.time() - started:.2f} seconds")
    return color_distributions, color_types, shapes, aspect_ratios, variances, unique_colors, entropies, image_paths_list
def summarize_image_types(image_paths, color_types):
    """Tabulate how many images fall into each color-type category."""
    frame = pd.DataFrame({'image_path': image_paths, 'color_type': color_types})
    counts = frame['color_type'].value_counts().reset_index()
    counts.columns = ['Color Type', 'Count']
    return counts
def summarize_dimensions(image_paths, shapes):
    """Build width, height and aspect-ratio frequency tables from shapes."""
    frame = pd.DataFrame({'image_path': image_paths, 'shape': shapes})

    def _value_counts(extract, label):
        # Count occurrences of a value derived from each (h, w, ...) shape.
        table = frame['shape'].apply(extract).value_counts().reset_index()
        table.columns = [label, 'Count']
        return table

    width_table = _value_counts(lambda s: s[1], 'Width')
    height_table = _value_counts(lambda s: s[0], 'Height')
    ratio_table = _value_counts(lambda s: s[1] / s[0], 'Aspect Ratio')
    return width_table, height_table, ratio_table
def summarize_color_metrics(image_paths, variances, unique_colors, entropies):
    """Collect the per-image color metrics into a single tidy DataFrame."""
    columns = {
        'image_path': image_paths,
        'variance': variances,
        'unique_colors': unique_colors,
        'entropy': entropies,
    }
    return pd.DataFrame(columns)
def plot_color_distributions(color_distributions, bins=32):
    """Plot the averaged R/G/B histograms for every class.

    All subplots share one y-limit — the largest density across every
    class and channel — so the classes are visually comparable.

    Args:
        color_distributions: {class_name: {'R'|'G'|'B': histogram}} as
            produced by merge_color_distributions().
        bins: Number of bins per histogram (length of the x-axis).
    """
    # Single pass over all class/channel peaks replaces the original's
    # manual max tracking; the unused `global_min` variable and the unused
    # enumerate index were dropped. Histograms are non-negative densities,
    # so default=0 matches the original's behavior for empty input.
    channel_peaks = [max(distributions[color])
                     for distributions in color_distributions.values()
                     for color in ['R', 'G', 'B']]
    global_max = max(channel_peaks, default=0)
    for class_name, distributions in color_distributions.items():
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        for color, ax in zip(['R', 'G', 'B'], axes):
            ax.bar(range(bins), distributions[color], color=color.lower(), alpha=0.7)
            ax.set_title(f'{class_name} - {color} Channel Distribution')
            ax.set_xlabel('Intensity')
            ax.set_ylabel('Density')
            ax.set_ylim(0, global_max)
        plt.tight_layout()
        plt.show()
def plot_filtered_images_by_entropy(filtered_image_paths, filtered_entropies, images_per_row=4):
    """Display the given images in a grid, each titled with its
    class/file name and entropy value."""
    total = len(filtered_image_paths)
    rows = (total + images_per_row - 1) // images_per_row  # ceil division
    fig, axes = plt.subplots(rows, images_per_row, figsize=(20, 5 * rows))
    axes = axes.flatten()
    for ax, (path, entropy_value) in zip(axes, zip(filtered_image_paths, filtered_entropies)):
        # OpenCV decodes to BGR; convert for matplotlib's RGB display.
        rgb = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
        class_name = os.path.basename(os.path.dirname(path))
        file_name = os.path.basename(path)
        ax.imshow(rgb)
        ax.set_title(f'{class_name}/{file_name}\nEntropy: {entropy_value:.2f}')
        ax.axis('off')
    # Blank out any leftover axes in the final, partially filled row.
    for ax in axes[total:]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()
# Gather every image path and compute per-image color statistics in parallel.
image_paths = get_image_paths(process_data_set.UNPROC_DATASET_LOC)
color_distributions, color_types, shapes, aspect_ratios, variances, unique_colors, entropies, image_paths_list = get_color_distributions(
    image_paths)
Running on 28 workers
Processing images: 100%|██████████| 6637/6637 [00:26<00:00, 248.23it/s]
Total processing time: 26.97 seconds
Image color summary¶
There seem to be no grayscale images and all images have 3 color channels.
# Tabulate color types (output below: every image is 3-channel color).
summary_table = summarize_image_types(image_paths, color_types)
summary_table  # notebook cell output
| Color Type | Count | |
|---|---|---|
| 0 | Color | 6637 |
Dimensions¶
There seems to be a lot of variance between image sizes and aspect ratios. This is a significant concern because models like ResNet require fixed input sizes (224x224 in this case).
We are using FastAI's 'ImageDataLoaders', which handles the resizing and adds padding to samples that are not square.
# Frequency tables for widths, heights and aspect ratios, then describe() each.
width_summary, height_summary, aspect_ratio_summary = summarize_dimensions(image_paths, shapes)
# NOTE(review): describe() here summarizes the *distinct* width/height/ratio
# values (one row per unique value, hence the differing counts in the output),
# not the per-image distribution — confirm this is intended.
width_stats = width_summary["Width"].describe()
height_stats = height_summary["Height"].describe()
aspect_ratio_stats = aspect_ratio_summary["Aspect Ratio"].describe()
summary_df = pd.DataFrame({
    'Width': width_stats,
    'Height': height_stats,
    'Aspect Ratio': aspect_ratio_stats
})
display(summary_df)
| Width | Height | Aspect Ratio | |
|---|---|---|---|
| count | 363.000000 | 505.000000 | 1156.000000 |
| mean | 709.586777 | 623.914851 | 1.325247 |
| std | 207.637554 | 183.902749 | 0.237855 |
| min | 259.000000 | 152.000000 | 0.561250 |
| 25% | 572.500000 | 487.000000 | 1.226641 |
| 50% | 702.000000 | 613.000000 | 1.336744 |
| 75% | 797.500000 | 749.000000 | 1.462122 |
| max | 1280.000000 | 1024.000000 | 2.857143 |
Color Variance and Entropy¶
We'll further analyze the images using these metrics:
Average variance of color channels in the all images:
- Variance = 0: All pixels in the image have the same color.
- High Variance: Indicates images with diverse color pixels.
Number of unique colors in each image
Entropy (shannon_entropy).
- Scale: 0 to log2(N), where N is the number of possible pixel values (log2(256) = 8 for 256 grayscale values).
- Min Entropy = 0: Perfectly uniform image (single color).
- High Entropy: Indicates images with a wide variety of colors and patterns.
- Scale: 0 to log2(N), where N is the number of possible pixel values (0 to 8 for 256 grayscale values).
# Per-image variance / unique-color / entropy metrics, summarized with describe().
# NOTE(review): image_paths is in os.walk order while the metric lists are in
# thread-completion order, so path/metric rows may be misaligned —
# image_paths_list would be the aligned choice; confirm. The describe()
# aggregates below are unaffected by the pairing.
image_entropy_summary = summarize_color_metrics(image_paths, variances, unique_colors, entropies)
variance_summary = image_entropy_summary['variance'].describe()
unique_colors_summary = image_entropy_summary['unique_colors'].describe()
entropy_summary = image_entropy_summary['entropy'].describe()
summary_df = pd.DataFrame({
    'Variance': variance_summary,
    'Unique Color': unique_colors_summary,
    'Entropy': entropy_summary
})
display(summary_df)
| Variance | Unique Color | Entropy | |
|---|---|---|---|
| count | 6637.000000 | 6637.000000 | 6637.000000 |
| mean | 3339.696198 | 9117.627844 | 7.575777 |
| std | 1040.674133 | 858.859522 | 0.289316 |
| min | 0.000000 | 1.000000 | 0.000000 |
| 25% | 2629.536395 | 8884.000000 | 7.488475 |
| 50% | 3218.788643 | 9367.000000 | 7.625352 |
| 75% | 3928.063424 | 9661.000000 | 7.732022 |
| max | 11091.372221 | 9968.000000 | 7.979296 |
This information will be used to filter out the samples which might not be useful for image analysis
# Rebuild the metrics table and attach the class (parent directory) label.
# NOTE(review): metric lists are in thread-completion order while image_paths
# is in os.walk order, so a row's class label may not match its metrics —
# image_paths_list would be the aligned choice; confirm.
image_entropy_summary = summarize_color_metrics(image_paths, variances, unique_colors, entropies)
image_entropy_summary['class'] = image_entropy_summary['image_path'].apply(
    lambda x: os.path.basename(os.path.dirname(x)))
# Faceted entropy histograms, one row per class, x-axis clipped below 6.5.
g = sns.displot(
    image_entropy_summary, x="entropy", row="class", binwidth=0.1, height=3, aspect=2.5,
    facet_kws=dict(margin_titles=True)
).set(xlim=(6.5, None))
g.fig.suptitle('Entropy Distribution by Class', y=1.01)
Text(0.5, 1.01, 'Entropy Distribution by Class')
# Accumulates filenames flagged as invalid during the entropy review.
invalid_image_paths = []
Samples with Very Low Entropy¶
The images below have very low entropy (i.e. in the bottom 0.5th percentile).
We can see that while some images are actual mushrooms that were photographed against a single color background (probably in a studio etc.) the image with 0.0 entropy is not valid. Additionally, we can see that one of the images is not an actual mushroom but just a random pattern (unfortunately our color variance based approach is not particularly useful at identifying images as such unless the pattern is very simple)
# Images in the bottom 0.5th percentile of entropy: near-uniform samples.
entropy_thresh = np.percentile(image_entropy_summary["entropy"], 0.5, axis=0)
filtered_image_paths = [path for path, entropy in zip(image_paths_list, entropies) if entropy < entropy_thresh]
filtered_entropies = [entropy for entropy in entropies if entropy < entropy_thresh]
plot_filtered_images_by_entropy(filtered_image_paths, filtered_entropies, images_per_row=5)
# Manually flag the two images judged invalid from the plot above.
invalid_image_paths.append("0051_rBIC-Uy9KzI.jpg")
invalid_image_paths.append("0127_1R8TZJseXgY.jpg")
High Entropy Images¶
These images tend to show very colorful mushrooms and have colorful varying forest backgrounds.
# Show the 25 highest-entropy images (most color/pattern variety).
top_n = 25
sorted_images_by_entropy = sorted(zip(image_paths_list, entropies), key=lambda x: x[1], reverse=True)[:top_n]
filtered_image_paths, filtered_entropies = zip(*sorted_images_by_entropy)
plot_filtered_images_by_entropy(filtered_image_paths, filtered_entropies, 5)
Color Chanel Distribution by Class¶
These plots show the normalized intensity (0 - 255) distributions of each color channel by class. The Y axis shows the normalized frequency (density) relative to all color channels (scaled to the highest individual value of any channel).
The charts are made by generating a histogram for each image, normalizing it (normalization process maintains the shape of the histogram, meaning the relative distribution of pixel intensities is preserved) All histograms in the class are then averaged.
As we would expect, green and red tend to be dominant in most images, which reflects the colors of most mushroom types and the forest-floor backgrounds they tend to be photographed against.
# Render the averaged per-class R/G/B channel distributions.
plot_color_distributions(color_distributions)
Data Sampling¶
Our next step is to split the dataset into test and training samples, roughly maintaining a 15% test sample size. The sample ratios used in our model training:
- Test: 15% of full dataset:
- Validation: 20% of remaining samples (i.e. 17% of full dataset)
- Remaining images used for training: 64%
We're using stratification to maintain a relatively representative distribution for all classes. The invalid images found during our analysis are also excluded.
# Stratified train/test split (15% test per class), excluding flagged images.
importlib.reload(process_data_set)
process_data_set.stratified_split(
    process_data_set.UNPROC_DATASET_LOC, process_data_set.OUTPUT_NAME, 0.15, invalid_images=invalid_image_paths
)
# Verify the split (e.g. no file appears in both train and test sets).
dataset_summary_df = process_data_set.verify_dataset(process_data_set.OUTPUT_NAME)
Verification Passed: No overlapping files between train and test sets.
# Reshape the per-class summary to long form for a grouped bar chart.
count_columns = ["Total Samples", "Training Samples", "Testing Samples"]
melted_df = dataset_summary_df.reset_index().melt(id_vars=["index"], value_vars=count_columns, var_name="Sample Type",
                                                  value_name="Count")
melted_df = melted_df.rename(columns={"index": "Class"})
# Grouped bars: total / training / testing sample counts per class.
plt.figure(figsize=(12, 6))
sns.barplot(x="Class", y="Count", hue="Sample Type", data=melted_df)
plt.title('Distribution of Samples per Class')
plt.xticks(rotation=45)
plt.show()
dataset_summary_df  # notebook cell output: split proportions table below
| Training Samples | Training Proportion (%) | Testing Samples | Testing Proportion (%) | Total Samples | |
|---|---|---|---|---|---|
| Boletus | 909 | 84.95% | 161 | 15.05% | 1070 |
| Suillus | 264 | 84.89% | 47 | 15.11% | 311 |
| Cortinarius | 709 | 85.01% | 125 | 14.99% | 834 |
| Russula | 972 | 85.04% | 171 | 14.96% | 1143 |
| Agaricus | 298 | 84.90% | 53 | 15.10% | 351 |
| Amanita | 636 | 85.03% | 112 | 14.97% | 748 |
| Entoloma | 309 | 84.89% | 55 | 15.11% | 364 |
| Hygrocybe | 268 | 85.08% | 47 | 14.92% | 315 |
| Lactarius | 1274 | 84.99% | 225 | 15.01% | 1499 |